import cv2
import os
import numpy as np
import math
import sys
import argparse


class LayerInfo:
    """Output data of one network layer, loaded from its comparison CSV file.

    The CSV file is ``Compare_<name>.csv`` inside ``dir``, where '<' and '>'
    in the layer name are replaced by '_'. Column 0 of each row is stored in
    ``reference`` and column 1 in ``compiled``. Rows that do not contain two
    parseable floats (e.g. the header row) are skipped entirely, so the two
    lists always have the same length.

    Args:
        dir: directory containing the CSV file (name kept for compatibility
            with existing callers even though it shadows the builtin).
        name: layer name as written in the report.
        size, stride, offset: shapes parsed from the report; stored unchanged.
    """

    def __init__(self, dir, name, size, stride, offset):
        self.name = name
        self.size = size
        self.stride = stride
        self.offset = offset
        self.filename = "Compare_" + name.replace("<", "_").replace(">", "_") + ".csv"
        self.reference = []
        self.compiled = []
        with open(os.path.join(dir, self.filename)) as csv_file:
            for line in csv_file:
                values = line.rstrip().split(",")
                try:
                    reference_value = float(values[0])
                    compiled_value = float(values[1])
                except (ValueError, IndexError):
                    continue  # header row or malformed line
                # Append only after both values parsed so the lists stay aligned.
                self.reference.append(reference_value)
                self.compiled.append(compiled_value)


class LayerComparison:
    """Interactive viewer for the layer report written by the "debugcompiler" tool.

    For every layer listed in the report, the reference output is shown as a
    grid of per-channel tiles, and a second window shows the absolute
    difference between the reference and compiled tiles. Channels whose summed
    values differ by more than ``threshold`` (relative to the reference sum)
    are outlined in red.
    """

    def __init__(self):
        self.layers = None
        # Horizontal pixel offset added to every window position (e.g. to move
        # the windows to another monitor).
        self.windowoffset = 0
        # Relative channel-sum difference above which a tile is outlined;
        # a value <= 0 disables the comparison.
        self.threshold = 0.1
        # The help text must go to ``description``: the first positional
        # parameter of ArgumentParser is ``prog`` (the program name).
        self.arg_parser = argparse.ArgumentParser(
            description="""Visualizes the report generated by the "debugcompiler" tool by showing side by side images
generated by each layer of the neural networks that were compared.""")
        self.arg_parser.add_argument("report", help="path to the report.md file")
        self.arg_parser.add_argument("--windowoffset", type=int, default=self.windowoffset,
                                     help="offset for the windows so you can move them to different monitor")
        self.arg_parser.add_argument("--threshold", type=float, default=self.threshold,
                                     help="threshold for comparing channels (draw red rectangles highlighting diff)")

    def parse_command_line(self, args):
        """Parse ``args`` (without the program name), load the report and apply options.

        Returns:
            True (argparse exits the process itself on invalid arguments).
        """
        args = self.arg_parser.parse_args(args)
        self.layers = self.load(args.report)
        self.windowoffset = args.windowoffset
        self.threshold = args.threshold
        return True

    def load(self, filename):
        """Parse the markdown report and load the CSV data of every layer it lists.

        A line starting with "## " names the current layer. A line starting
        with "Output size:" holds comma-separated "key: A x B x C" fields; the
        keys "Output size", "stride" and "offset" are recognised. Each such
        line produces one LayerInfo whose CSV file is looked up in the
        report's directory.

        Returns:
            list of LayerInfo in report order.
        """
        name = ""
        layers = []
        report_dir = os.path.dirname(filename)
        size = None
        stride = None
        offset = None
        with open(filename) as f:
            for line in f:
                if line.startswith("## "):
                    name = line[3:].rstrip()
                    # Start every layer fresh so a field missing from its
                    # report is None instead of the previous layer's value.
                    size = None
                    stride = None
                    offset = None

                parts = line.split(',')
                if not parts[0].startswith("Output size:"):
                    continue

                for part in parts:
                    nvalue = part.split(":")
                    if len(nvalue) != 2:
                        continue
                    varname = nvalue[0].strip()
                    shape = [int(v) for v in nvalue[1].strip().split(" x ")]
                    if varname == "Output size":
                        size = shape
                    elif varname == "stride":
                        stride = shape
                    elif varname == "offset":
                        offset = shape

                if size is not None:
                    print("loading layer ", name, " of size ", size, ", stride ", stride, ", offset", offset)
                    layers.append(LayerInfo(report_dir, name, size, stride, offset))
        return layers

    def tile_channels(self, img, stride):
        """Lay the channels of ``img`` out as a 2-D grid of single-channel tiles.

        Args:
            img: array indexed as img[row, col, channel], already reshaped to
                ``stride``.
            stride: [H, W, C], the shape ``img`` was reshaped to.

        Returns:
            float array of shape (H * rows, W * cols, 1) with rows * cols == C.
            Channels fill the grid column by column: channel c lands in grid
            column c // rows, grid row c % rows (compare_tiles walks the grid
            in the same order).
        """
        # ``stride`` is the numpy shape of img: stride[0] counts rows (height),
        # stride[1] counts columns (width).
        h = stride[0]
        w = stride[1]
        channels = stride[2]

        # Factor the channel count into a roughly square rows x cols grid.
        rows = 1
        cols = 1
        while channels > 1:
            s = math.isqrt(channels)
            if s * s == channels:
                cols *= s
                rows *= s
                channels = 1
            elif channels % 25 == 0:
                cols *= 5
                rows *= 5
                channels //= 25
            elif channels % 10 == 0:
                cols *= 5
                rows *= 2
                channels //= 10
            elif channels % 4 == 0:
                cols *= 2
                rows *= 2
                channels //= 4
            else:
                cols *= channels
                channels = 1

        result = np.zeros([h * rows, w * cols, 1])
        c = 0
        for i in range(cols):
            x = i * w
            for j in range(rows):
                y = j * h
                result[y:y + h, x:x + w] = img[:, :, c:c + 1]
                c += 1
        return result

    def compare_image(self, a, b):
        """Return |sum(a) - sum(b)| / |sum(a)|, the relative difference of the sums.

        When sum(a) is 0, returns 0.0 if sum(b) is also 0 and infinity
        otherwise, instead of dividing by zero.
        """
        da = float(np.sum(a))
        db = float(np.sum(b))
        diff = abs(da - db)
        if da == 0:
            return 0.0 if diff == 0 else math.inf
        # abs() keeps the ratio non-negative when the reference sum is negative.
        return diff / abs(da)

    def compare_tiles(self, a, b, ta, tb):
        """Outline in red every channel tile whose a/b sums differ by more than threshold.

        Args:
            a, b: (H, W, C) reference and compiled arrays.
            ta, tb: tiled BGR images of a and b (grid layout as produced by
                tile_channels); rectangles are drawn on them in place.
        """
        if self.threshold <= 0:
            return
        h = a.shape[0]
        w = a.shape[1]
        rows = ta.shape[0] // h
        cols = ta.shape[1] // w
        c = 0
        for i in range(cols):
            x = i * w
            for j in range(rows):
                y = j * h
                v = self.compare_image(a[:, :, c:c + 1], b[:, :, c:c + 1])
                if v > self.threshold:
                    print("comparing channel ", c, " found difference ", v)
                    # Vertical extent is the tile height h; stop one pixel short
                    # so the outline stays inside this tile.
                    corner = (x + w - 1, y + h - 1)
                    cv2.rectangle(ta, (x, y), corner, (0, 0, 255), 1)
                    cv2.rectangle(tb, (x, y), corner, (0, 0, 255), 1)
                c += 1

    def rgb_image(self, a):
        """Stretch ``a`` linearly to 0-255 and return it as an 8-bit BGR image.

        A constant array is only shifted to 0 (no scaling).
        """
        min_value = np.amin(a)
        max_value = np.amax(a)
        a = a - min_value
        if min_value < max_value:
            a = a * (255.0 / (max_value - min_value))
        gray = a.astype(np.uint8)
        return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

    def reshape(self, data, stride):
        """Reshape ``data`` to ``stride``, with fallbacks when the sizes disagree.

        On mismatch, tries a square (size, size, 3) shape; if that also fails,
        returns zeros shaped ``stride``.
        """
        try:
            return np.reshape(data, stride)
        except (ValueError, TypeError):
            # try again with our own calculation (assumes data is square)
            print("error: data doesn't match stride: %s" % (str(stride)))
            size = int(math.sqrt(len(data) / 3))
            try:
                return np.reshape(data, (size, size, 3))
            except ValueError:
                print("error: data is not square")
                data = np.zeros(stride[0] * stride[1] * stride[2])
                return np.reshape(data, stride)

    def show_layer(self, i):
        """Show layer ``i`` (reference tiles + difference) and block until ESC is pressed."""
        layer = self.layers[i]
        print("Showing Layer {} , size={}, stride={}, offset={}".format(
            layer.name, layer.size, layer.stride, layer.offset))
        stride = layer.stride
        name = "Reference Layer " + str(i) + ": " + layer.name
        a = np.reshape(layer.reference, stride)
        b = np.reshape(layer.compiled, stride)

        show_difference = True

        ta = self.rgb_image(self.tile_channels(a, stride))
        tb = self.rgb_image(self.tile_channels(b, stride))

        self.compare_tiles(a, b, ta, tb)

        cv2.imshow(name, ta)
        cv2.moveWindow(name, self.windowoffset, 0)

        if show_difference:
            name2 = "Difference Layer " + str(i) + ": " + layer.name
            # absdiff instead of ``ta - tb``: uint8 subtraction wraps around.
            cv2.imshow(name2, cv2.absdiff(ta, tb))
        else:
            name2 = "Compiled Layer " + str(i) + ": " + layer.name
            cv2.imshow(name2, tb)

        # Put the second window below the first for wide images, otherwise
        # to its right.
        y = 0
        x = 0
        if ta.shape[1] > 2 * ta.shape[0]:
            y = max(ta.shape[0], 150) + 10   # 150 is the minimum window size
        else:
            x = max(ta.shape[1], 316) + 10   # 316 is the minimum window size
        cv2.moveWindow(name2, x + self.windowoffset, y)
        self.wait_for_escape()
        cv2.destroyWindow(name)
        cv2.destroyWindow(name2)

    def wait_for_escape(self):
        """Block, polling OpenCV's event loop, until the ESC key (code 27) is pressed."""
        print("Press ESC to continue...")
        while True:
            if cv2.waitKey(1) & 0xFF == 27:
                break

    def show_layers(self):
        """Show every loaded layer in order, one at a time."""
        for i in range(len(self.layers)):
            self.show_layer(i)


if __name__ == '__main__':
    lc = LayerComparison()
    # parse_args expects the arguments without the program name; slicing
    # (rather than popping) leaves the global sys.argv untouched.
    if lc.parse_command_line(sys.argv[1:]):
        lc.show_layers()
